{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "Classification of Diseased and Healthy 3D Coronary Artery Shapes Using TensorFlow (Minimal Reproducible Example)"
      ],
      "metadata": {
        "id": "YlZXKxuzBnsE"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i6eIqm2s-4aU"
      },
      "outputs": [],
      "source": [
        "!pip install MedShapeNetCore\n",
        "!pip install tensorflow[and-cuda]"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the ASOCA dataset."
      ],
      "metadata": {
        "id": "Pqmt6lkGCDib"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the MedShapeNetCore command-line module to download the ASOCA dataset.\n",
        "!python -m MedShapeNetCore download ASOCA"
      ],
      "metadata": {
        "id": "uDB_WpSuCFjL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Import the necessary packages."
      ],
      "metadata": {
        "id": "wZ3yJ9fACJ7-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "from MedShapeNetCore.MedShapeNetCore import MyDict,MSNLoader,MSNVisualizer,MSNSaver,MSNTransformer"
      ],
      "metadata": {
        "id": "M1ug7iwFAykf"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "AEqNljrQAySX"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load and prepare the dataset."
      ],
      "metadata": {
        "id": "cnNpSBj7CRTa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "msn_loader=MSNLoader()\n",
        "ASOCA_DATA=msn_loader.load('ASOCA')\n",
        "shape_data=ASOCA_DATA['mask']\n",
        "shape_labels=ASOCA_DATA['labels']\n",
        "print(shape_data.shape)\n",
        "print(shape_labels)\n",
        "x_train=np.expand_dims(shape_data, axis=4)\n",
        "y_train=shape_labels"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uSPHjq0RCUi1",
        "outputId": "68734775-c197-48fe-cff7-cfaf3aa21e50"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "current dataset: ./medshapenetcore_npz/medshapenetcore_ASOCA.npz\n",
            "available keys in the dataset: ['mask', 'point', 'mesh', 'labels']\n",
            "(40, 256, 256, 256)\n",
            "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.\n",
            " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Construct the 3D CNN classification model."
      ],
      "metadata": {
        "id": "WSiGRWtMDEr3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_model(width=256, height=256, depth=256):\n",
        "    \"\"\"Build a 3D convolutional neural network model.\"\"\"\n",
        "\n",
        "    inputs = keras.Input((width, height, depth, 1))\n",
        "\n",
        "    x = layers.Conv3D(filters=64, kernel_size=3, activation=\"relu\",padding='same',strides=(2, 2, 2))(inputs)\n",
        "    #x = layers.MaxPool3D(pool_size=2)(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "\n",
        "    x = layers.Conv3D(filters=64, kernel_size=3, activation=\"relu\",padding='same')(x)\n",
        "    #x = layers.MaxPool3D(pool_size=2)(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "\n",
        "    x = layers.Conv3D(filters=128, kernel_size=3, activation=\"relu\",padding='same',strides=(2, 2, 2))(x)\n",
        "    #x = layers.MaxPool3D(pool_size=2)(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "\n",
        "\n",
        "    x = layers.Conv3D(filters=128, kernel_size=3, activation=\"relu\",padding='same')(x)\n",
        "    #x = layers.MaxPool3D(pool_size=2)(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "\n",
        "\n",
        "    x = layers.Conv3D(filters=64, kernel_size=3, activation=\"relu\",padding='same')(x)\n",
        "    #x = layers.MaxPool3D(pool_size=2)(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "\n",
        "    x = layers.GlobalAveragePooling3D()(x)\n",
        "    x = layers.Dense(units=64, activation=\"relu\")(x)\n",
        "    x = layers.Dropout(0.4)(x)\n",
        "\n",
        "    outputs = layers.Dense(units=1, activation=\"sigmoid\")(x)\n",
        "\n",
        "    # Define the model.\n",
        "    model = keras.Model(inputs, outputs, name=\"3dcnn\")\n",
        "    return model\n",
        "\n",
        "\n",
        "# Build model.\n",
        "model = get_model(width=256, height=256, depth=256)\n",
        "model.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-ZTDax_hDJy5",
        "outputId": "f3d94722-a620-4364-884f-cd2d2b27b335"
      },
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"3dcnn\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " input_3 (InputLayer)        [(None, 256, 256, 256,    0         \n",
            "                             1)]                                 \n",
            "                                                                 \n",
            " conv3d_10 (Conv3D)          (None, 128, 128, 128, 6   1792      \n",
            "                             4)                                  \n",
            "                                                                 \n",
            " batch_normalization_10 (Ba  (None, 128, 128, 128, 6   256       \n",
            " tchNormalization)           4)                                  \n",
            "                                                                 \n",
            " conv3d_11 (Conv3D)          (None, 128, 128, 128, 6   110656    \n",
            "                             4)                                  \n",
            "                                                                 \n",
            " batch_normalization_11 (Ba  (None, 128, 128, 128, 6   256       \n",
            " tchNormalization)           4)                                  \n",
            "                                                                 \n",
            " conv3d_12 (Conv3D)          (None, 64, 64, 64, 128)   221312    \n",
            "                                                                 \n",
            " batch_normalization_12 (Ba  (None, 64, 64, 64, 128)   512       \n",
            " tchNormalization)                                               \n",
            "                                                                 \n",
            " conv3d_13 (Conv3D)          (None, 64, 64, 64, 128)   442496    \n",
            "                                                                 \n",
            " batch_normalization_13 (Ba  (None, 64, 64, 64, 128)   512       \n",
            " tchNormalization)                                               \n",
            "                                                                 \n",
            " conv3d_14 (Conv3D)          (None, 64, 64, 64, 64)    221248    \n",
            "                                                                 \n",
            " batch_normalization_14 (Ba  (None, 64, 64, 64, 64)    256       \n",
            " tchNormalization)                                               \n",
            "                                                                 \n",
            " global_average_pooling3d_2  (None, 64)                0         \n",
            "  (GlobalAveragePooling3D)                                       \n",
            "                                                                 \n",
            " dense_4 (Dense)             (None, 64)                4160      \n",
            "                                                                 \n",
            " dropout_2 (Dropout)         (None, 64)                0         \n",
            "                                                                 \n",
            " dense_5 (Dense)             (None, 1)                 65        \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 1003521 (3.83 MB)\n",
            "Trainable params: 1002625 (3.82 MB)\n",
            "Non-trainable params: 896 (3.50 KB)\n",
            "_________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Compile and train the model."
      ],
      "metadata": {
        "id": "N-buEVTlDVUo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# compile model\n",
        "initial_learning_rate = 0.0001\n",
        "lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
        "    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True\n",
        ")\n",
        "model.compile(\n",
        "    loss=\"binary_crossentropy\",\n",
        "    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),\n",
        "    metrics=[\"acc\"],\n",
        ")\n",
        "checkpoint_cb = keras.callbacks.ModelCheckpoint(\n",
        "    \"3d_image_classification.h5\", save_best_only=False\n",
        ")\n",
        "early_stopping_cb = keras.callbacks.EarlyStopping(monitor=\"val_acc\", patience=15)\n",
        "epochs = 200\n",
        "# training\n",
        "\n",
        "model.fit(\n",
        "    x_train,\n",
        "    y_train,\n",
        "    validation_split=0.20,\n",
        "    epochs=epochs,\n",
        "    shuffle=True,\n",
        "    verbose=1,\n",
        "    callbacks=[checkpoint_cb, early_stopping_cb],\n",
        "    #callbacks=[checkpoint_cb],\n",
        ")"
      ],
      "metadata": {
        "id": "QvJkK2T2DXgZ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
